package com.rpc.learn.regist;

import com.rpc.learn.loadbalance.RandomServiceLoadBalance;
import com.rpc.learn.loadbalance.ServiceLoadBalance;
import com.rpc.learn.transport.NettyRpcServer;
import org.apache.curator.RetryPolicy;
import org.apache.curator.framework.CuratorFramework;
import org.apache.curator.framework.CuratorFrameworkFactory;
import org.apache.curator.framework.recipes.cache.PathChildrenCache;
import org.apache.curator.framework.recipes.cache.PathChildrenCacheEvent;
import org.apache.curator.framework.recipes.cache.PathChildrenCacheListener;
import org.apache.curator.retry.ExponentialBackoffRetry;
import org.apache.zookeeper.CreateMode;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * ZooKeeper-backed service registry.
 *
 * <p>Providers register themselves as ephemeral nodes under
 * {@code /simpleRpc/service/<serviceName>/<host>:<port>}. Consumers look a service up by listing
 * the children of its directory; the list is cached locally per directory and refreshed by a
 * {@link PathChildrenCache} listener whenever ZooKeeper reports a change.
 *
 * <p>The ZooKeeper address and the load-balancing strategy are hard-coded in the static
 * initializer for now; ideally both would come from configuration.
 */
public class ZkServiceRegister implements ServiceRegister {
    /** Local development ZooKeeper ensemble. */
    private static final String ZK_URL = "127.0.0.1:2181";
    /** Root under which every service directory lives; note the trailing '/'. */
    private static final String BASE_SERVICE_PATH = "/simpleRpc/service/";
    private static final CuratorFramework zkClient;
    private static final ServiceLoadBalance serviceLoadBalance;
    /** Service directory path -> last known list of "host:port" child node names. */
    private static final Map<String, List<String>> serviceAddressMap = new ConcurrentHashMap<>();

    static {
        // TODO: read the ZooKeeper address from configuration; a development default is used for now.
        // Exponential back-off: 1000 ms base sleep, at most 3 retries.
        RetryPolicy retryPolicy = new ExponentialBackoffRetry(1000, 3);
        zkClient = CuratorFrameworkFactory.newClient(ZK_URL, retryPolicy);
        zkClient.start();
        // TODO: read the load-balancing strategy from configuration; random selection for now.
        serviceLoadBalance = new RandomServiceLoadBalance();
    }

    /**
     * Registers this process as a provider of {@code serviceName}.
     *
     * <p>Creates the ephemeral node
     * {@code /simpleRpc/service/<serviceName>/<localHostAddress>:<NettyRpcServer.PORT>},
     * creating missing parent directories. Any failure is printed and otherwise ignored.
     *
     * @param serviceName name of the service being exposed
     */
    @Override
    public void regist(String serviceName) {
        try {
            // e.g. /simpleRpc/service/HelloService/127.0.0.1:12306
            // NOTE(review): getLocalHost() may resolve to a loopback address on some hosts,
            // which remote consumers cannot reach — confirm for the deployment environment.
            String servicePath = BASE_SERVICE_PATH + serviceName + "/"
                    + InetAddress.getLocalHost().getHostAddress() + ":" + NettyRpcServer.PORT;
            // EPHEMERAL: ZooKeeper deletes the node when this client's session ends.
            zkClient.create().creatingParentsIfNeeded().withMode(CreateMode.EPHEMERAL).forPath(servicePath);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Picks one provider address for {@code serviceName}.
     *
     * <p>A service may have several instances, e.g.
     * {@code /simpleRpc/service/HelloService/127.0.0.1:12306} and
     * {@code /simpleRpc/service/HelloService/127.0.0.1:12345}. The instance list comes from a local
     * cache kept in sync with ZooKeeper, and the configured load balancer chooses among them.
     *
     * @param serviceName name of the service to resolve
     * @return the chosen provider address, or {@code null} when no provider is registered or the
     *     lookup fails (the failure is printed)
     */
    @Override
    public InetSocketAddress lookup(String serviceName) {
        String serviceDir = BASE_SERVICE_PATH + serviceName;
        try {
            List<String> services = getAllAddress(serviceDir);
            if (services == null || services.isEmpty()) {
                // No live provider: there is nothing for the load balancer to choose from.
                return null;
            }
            String address = serviceLoadBalance.selectService(services);
            return toSocketAddress(address);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Parses a {@code host:port} node name into a socket address.
     *
     * <p>Splits on the LAST ':' so IPv6 host literals, which themselves contain ':', keep their
     * host part intact.
     *
     * @throws IllegalArgumentException if the address has no host or port part
     * @throws NumberFormatException if the port is not an integer
     */
    private static InetSocketAddress toSocketAddress(String address) {
        int separator = address.lastIndexOf(':');
        if (separator <= 0 || separator == address.length() - 1) {
            throw new IllegalArgumentException("Malformed service address: " + address);
        }
        String host = address.substring(0, separator);
        int port = Integer.parseInt(address.substring(separator + 1));
        return new InetSocketAddress(host, port);
    }

    /**
     * Returns the {@code host:port} children of {@code serviceDir}, served from the local cache.
     *
     * <p>On the first lookup of a directory, this method fetches the children from ZooKeeper,
     * caches them and installs a watcher that keeps the cache fresh. Only the thread that wins
     * {@code putIfAbsent} registers the watcher, so concurrent first lookups do not create
     * duplicate watchers.
     */
    private List<String> getAllAddress(String serviceDir) throws Exception {
        List<String> cached = serviceAddressMap.get(serviceDir);
        if (cached != null) {
            return cached;
        }
        // Fetch all "ip:port" children of the service directory.
        List<String> addressList = zkClient.getChildren().forPath(serviceDir);
        List<String> existing = serviceAddressMap.putIfAbsent(serviceDir, addressList);
        if (existing != null) {
            // Another thread populated the cache (and registers the watcher) first.
            return existing;
        }
        try {
            registerWatcher(serviceDir);
        } catch (Exception e) {
            // Without a watcher this entry would never be refreshed; drop it so the next lookup retries.
            serviceAddressMap.remove(serviceDir);
            throw e;
        }
        return addressList;
    }

    /**
     * Watches the children of {@code serviceDir}. On every event the listener receives, it
     * re-reads the full child list from ZooKeeper and replaces the cache entry.
     *
     * <p>NOTE(review): the {@link PathChildrenCache} is never closed; there is currently no
     * shutdown hook in this class to release it.
     */
    private void registerWatcher(String serviceDir) throws Exception {
        // cacheData=false: only child names are used, so node payloads need not be fetched/kept.
        PathChildrenCache pathChildrenCache = new PathChildrenCache(zkClient, serviceDir, false);
        PathChildrenCacheListener cacheListener = new PathChildrenCacheListener() {
            @Override
            public void childEvent(CuratorFramework client, PathChildrenCacheEvent event) throws Exception {
                // The event type is not inspected; any notification triggers a full re-read.
                List<String> list = client.getChildren().forPath(serviceDir);
                serviceAddressMap.put(serviceDir, list);
            }
        };
        pathChildrenCache.getListenable().addListener(cacheListener);
        pathChildrenCache.start();
    }
}
